{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Multi-Label Classification Tutorial\n",
    "\n",
    "This tutorial shows how to use Tribuo's MultiLabel package to perform [multi-label classification](https://en.wikipedia.org/wiki/Multi-label_classification) tasks. Multi-label classification is the task of assigning a *set* of labels to a given example from a specific label domain, as opposed to multi-class classification which is assigning a *single* label to a given example.\n",
    "\n",
    "Tribuo provides linear model and factorization machine algorithms for native multi-label prediction, along with ensemble methods that either predict each label independently or as part of a [classifier chain](http://www.cs.waikato.ac.nz/~ml/publications/2009/chains.pdf), using any of Tribuo's classification algorithms as the base learners. Both the linear models, factorization machines and the `IndependentMultiLabelTrainer` use the *Binary Relevance* approach to multi-label prediction, where each label is predicted independently. The `ClassifierChainTrainer` and `CCEnsembleTrainer` use classifier chains which incorporate label structure into the prediction. In this tutorial we'll cover loading in multi-label data, performing predictions using several binary relevance based classifiers along with some classifier chains, and finally we'll evaluate the multi-label models using Tribuo's multi-label evaluation package.\n",
    "\n",
    "## Setup\n",
    "\n",
    "First you'll need a copy of the multi-label yeast dataset (we'll download these from the [LibSVM](https://www.csie.ntu.edu.tw/~cjlin/libsvm/) dataset repo):\n",
    "\n",
    "```\n",
    "wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel/yeast_train.svm.bz2\n",
    "wget https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multilabel/yeast_test.svm.bz2\n",
    "```\n",
    "\n",
    "Then you should extract it using your preferred method. On macOS and Linux you can use `bunzip2`, and on Windows there are several packages which can extract bz2 files (e.g., [7-zip](https://www.7-zip.org/)).\n",
    "\n",
    "This dataset has 14 labels which represent different functional groups and the task is to predict the functional groups a gene belongs in based on micro-array expression measurements. Fortunately we don't need a PhD in Genetics to use this dataset as a benchmark, though obviously domain knowledge would be critical if we wanted to actually deploy any model based on this data.\n",
    "\n",
    "Now we'll load in the necessary jars and import some packages from the JDK and Tribuo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%jars ./tribuo-multilabel-sgd-4.3.0-jar-with-dependencies.jar\n",
    "%jars ./tribuo-classification-experiments-4.3.0-jar-with-dependencies.jar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.nio.file.Paths;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import org.tribuo.*;\n",
    "import org.tribuo.classification.Label;\n",
    "import org.tribuo.classification.dtree.CARTClassificationTrainer;\n",
    "import org.tribuo.classification.dtree.impurity.*;\n",
    "import org.tribuo.datasource.*;\n",
    "import org.tribuo.math.optimisers.*;\n",
    "import org.tribuo.multilabel.*;\n",
    "import org.tribuo.multilabel.baseline.*;\n",
    "import org.tribuo.multilabel.ensemble.*;\n",
    "import org.tribuo.multilabel.evaluation.*;\n",
    "import org.tribuo.multilabel.sgd.linear.*;\n",
    "import org.tribuo.multilabel.sgd.objectives.*;\n",
    "import org.tribuo.util.Util;"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Loading the data\n",
    "\n",
    "There are two main forms for multi-label data in columnar representations. Either the dataset stores the labels in a single column using some delimiter (e.g., \"first_label,third_label\"), resulting in a sparse representation of the labels, or each label is stored in it's own column with a flag representing if that label is present (e.g., \"TRUE\" or \"1\"), resulting in a dense representation of the labels. Tribuo can load both formats, though currently the `MultiLabelFactory` only supports comma separated labels when parsing inputs directly from a `String`. When processing multi-label values through a `RowProcessor` then the factory receives a `List<String>` and processes each separate label appropriately.\n",
    "\n",
    "The yeast dataset we downloaded is stored in libsvm format which uses a sparse representation of the labels, so we'll use Tribuo's `LibSVMDataSource` to load it in, and process the outputs through a `MultiLabelFactory`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training data size = 1500, number of features = 103, number of classes = 14\n",
      "Testing data size = 917, number of features = 103, number of classes = 14\n"
     ]
    }
   ],
   "source": [
    "var factory = new MultiLabelFactory();\n",
    "var trainSource = new LibSVMDataSource<>(Paths.get(\".\",\"yeast_train.svm\"),factory);\n",
    "var testSource = new LibSVMDataSource<>(Paths.get(\".\",\"yeast_test.svm\"),factory,trainSource.isZeroIndexed(),trainSource.getMaxFeatureID());\n",
    "var train = new MutableDataset<>(trainSource);\n",
    "var test = new MutableDataset<>(testSource);\n",
    "System.out.println(String.format(\"Training data size = %d, number of features = %d, number of classes = %d\",train.size(),train.getFeatureMap().size(),train.getOutputInfo().size()));\n",
    "System.out.println(String.format(\"Testing data size = %d, number of features = %d, number of classes = %d\",test.size(),test.getFeatureMap().size(),test.getOutputInfo().size()));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In Tribuo we represent a multi-label task using the `org.tribuo.multilabel.MultiLabel` output, which internally uses a set of `org.tribuo.classification.Label` objects to store the individual labels. This means that unlike most Tribuo prediction type packages, `tribuo-multilabel-core` depends on another output core package `tribuo-classification-core`. `MultiLabel` is a sparse representation of the labels, only the `Label`s which are active are stored in the `MultiLabel` object.\n",
    "\n",
    "We can inspect the first output from the training dataset to see this:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "First output = (LabelSet={2,3})\n",
      "Second output = (LabelSet={11,12,6,7})\n"
     ]
    }
   ],
   "source": [
    "System.out.println(\"First output = \" + train.getExample(0).getOutput());\n",
    "System.out.println(\"Second output = \" + train.getExample(1).getOutput());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This first example is tagged with labels 2 & 3, and the second one is tagged with 6, 7, 11 and 12. Unfortunately the LibSVM format we loaded in uses numbers rather than names for the labels, but if there are more descriptive names present when the data is loaded in then those would be used as the label names.\n",
    "\n",
    "## The task of multi-label prediction\n",
    "\n",
    "Multi-label problems can be approached in several different ways. A common approach is to treat each label as an independent function of the input features, this leads to the *binary relevance* approach where each label is independent from each other, and multi-label classification can be thought of as a set of standard binary classification problems. This approach scales well, but if there is underlying structure in the label space (e.g., the label \"human\" implies the label \"animal\", but the label \"animal\" does not imply \"human\", so they are not independent), then this approach ignores useful information from the training data and may underperform more complicated approaches. Another popular way to convert a multi-label problem into a standard classification problem is via a *label powerset*, where each unique combination of the individual labels is treated as a single label in a large multi-class classification problem. While this allows the learning algorithm to fully capture any interactions between the labels, the label powerset is exponential in the number of labels, which rapidly makes this approach intractable as the number of labels increases though it can be useful in small label spaces. Tribuo currently focuses on binary relevance and other approaches which don't require exponential computation, though we're happy to discuss incorporating label powerset methods if people have need for them. \n",
    "\n",
    "## Training Binary Relevance models\n",
    "\n",
    "Now we'll train a few different binary relevance models (i.e., independent predictions of each label). First we'll use Tribuo's multi-label `LinearSGDModel` which natively makes multi-label predictions, then we'll wrap a binary classification decision tree into a multi-label predictor using `IndependentMultiLabelTrainer` and `IndependentMultiLabelModel`. Note: Tribuo has three classes called `LinearSGDModel`, one each for `Label`, `MultiLabel`, and `Regressor`, so the `LinearSGDModel` used in this tutorial is `org.tribuo.multilabel.sgd.linear.LinearSGDModel`, and the one used in the *multi-class* classification tutorials is `org.tribuo.classification.sgd.linear.LinearSGDModel`.\n",
    "\n",
    "Tribuo's multi-label SGD package supports two different objective functions, Binary Cross-Entropy and Hinge loss. The BCE loss produces probabilitistic outputs thresholded at 0.5, whereas the hinge loss produces scores thresholded at 0. As Tribuo usually produces scores for each possible label, these thresholds are used to determine when a particular label is present in the `MultiLabel` object representing the predicted label set. As you may have seen in other tutorials, Tribuo uses stochastic gradient descent to fit it's linear models, so we'll define a gradient optimizer along with the loss function."
   ]
  },
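  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The thresholding step described above can be sketched in plain Java. This is an illustration only, not Tribuo's implementation: the class name, the method name, the score values and the choice of a strict `>` comparison are all assumptions.\n",
    "\n",
    "```java\n",
    "import java.util.Set;\n",
    "import java.util.TreeSet;\n",
    "\n",
    "public class ThresholdSketch {\n",
    "    // Returns the indices of the labels whose score is strictly above the threshold.\n",
    "    static Set<Integer> activeLabels(double[] scores, double threshold) {\n",
    "        Set<Integer> active = new TreeSet<>();\n",
    "        for (int i = 0; i < scores.length; i++) {\n",
    "            if (scores[i] > threshold) {\n",
    "                active.add(i);\n",
    "            }\n",
    "        }\n",
    "        return active;\n",
    "    }\n",
    "\n",
    "    public static void main(String[] args) {\n",
    "        double[] probabilities = {0.9, 0.2, 0.7}; // made-up BCE-style outputs\n",
    "        double[] margins = {1.3, -0.4, 0.1};      // made-up hinge-style outputs\n",
    "        System.out.println(activeLabels(probabilities, 0.5)); // [0, 2]\n",
    "        System.out.println(activeLabels(margins, 0.0));       // [0, 2]\n",
    "    }\n",
    "}\n",
    "```\n",
    "\n",
    "Both score types end up as the same kind of sparse label set, which is why the rest of the pipeline does not need to know which objective was used."
   ]
  },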
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "var linTrainer = new LinearSGDTrainer(new BinaryCrossEntropy(),new AdaGrad(0.1,0.1),5,1000,1,Trainer.DEFAULT_SEED);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We train the model the same way we train the rest of Tribuo's models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Linear model training took (00:00:00:192)\n"
     ]
    }
   ],
   "source": [
    "var linStartTime = System.currentTimeMillis();\n",
    "var linModel = linTrainer.train(train);\n",
    "var linEndTime = System.currentTimeMillis();\n",
    "System.out.println();\n",
    "System.out.println(\"Linear model training took \" + Util.formatDuration(linStartTime,linEndTime));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tribuo doesn't have a native implementation of a multi-label decision tree, but it does have a multi-class decision tree, which we can convert into a multi-label predictor using `IndependentMultiLabelTrainer`. Now let's train a model using a decision tree to predict each label independently. First we define the binary classification trainer, then we'll use `IndependentMultiLabelTrainer` to wrap that `Trainer<Label>` and convert it into a `Trainer<MultiLabel>`, before training as usual."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Tree model training took (00:00:03:188)\n"
     ]
    }
   ],
   "source": [
    "Trainer<Label> treeTrainer = new CARTClassificationTrainer(6,10,0.0f,1.0f,false,new Entropy(),1L);\n",
    "Trainer<MultiLabel> dtTrainer = new IndependentMultiLabelTrainer(treeTrainer);\n",
    "var dtStartTime = System.currentTimeMillis();\n",
    "var dtModel = dtTrainer.train(train);\n",
    "var dtEndTime = System.currentTimeMillis();\n",
    "System.out.println();\n",
    "System.out.println(\"Tree model training took \" + Util.formatDuration(dtStartTime,dtEndTime));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We've now got two different models, so let's measure their performance.\n",
    "\n",
    "## Evaluating multi-class problems\n",
    "\n",
    "Multi-label problems have many evaluation options available, as many standard classification evaluation measures like accuracy, precision, recall and F1 can be applied at the label level, and there are also many set level metric such as the [Jaccard Index](https://en.wikipedia.org/wiki/Jaccard_index) which can be used to compare the predicted label set against the ground truth one. In Tribuo we have access to most of the metrics available for multi-class classification problems, and v4.2 began adding set level metrics as well. If there are useful metrics that aren't implemented in Tribuo raise an issue on Tribuo's [Github page](https://github.com/oracle/tribuo).\n",
    "\n",
    "If you want to use the predicted scores for each of the labels separately (e.g., to analyse the model's performance) then as usual the `Map<String,MultiLabel>` available from `Prediction.getOutputScores()` has the full distribution. This map behaves slightly counter-intuitively, as each value is a `MultiLabel` object containing a single `Label`, and the key is the output of `Label.toString()`. This allows the labels to be inspected individually, but it is a little uncomfortable if you're used to working with a multi-label specific API. However it maintains conformity across all of Tribuo's different prediction APIs, both for predictions and evaluations, which makes it easier to incorporate lots of ML models into a larger system.\n",
    "\n",
    "We use the same evaluation paradigm as other Tribuo prediction tasks, first we construct an `Evaluator` and then feed it a model and some test data to produce an `Evaluation` which contains the appropriate performance metrics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "var eval = new MultiLabelEvaluator();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we'll look at the linear model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Linear model evaluation took (00:00:00:063)\n",
      "Class                           n          tp          fn          fp      recall        prec          f1\n",
      "(LabelSet={12})               683         677           6         230       0.991       0.746       0.852\n",
      "(LabelSet={13})                13           0          13           0       0.000       0.000       0.000\n",
      "(LabelSet={0})                286         131         155          45       0.458       0.744       0.567\n",
      "(LabelSet={1})                393         162         231         130       0.412       0.555       0.473\n",
      "(LabelSet={2})                385         231         154          95       0.600       0.709       0.650\n",
      "(LabelSet={3})                330         169         161          79       0.512       0.681       0.585\n",
      "(LabelSet={4})                281         106         175          35       0.377       0.752       0.502\n",
      "(LabelSet={5})                219          26         193          14       0.119       0.650       0.201\n",
      "(LabelSet={6})                167           0         167           0       0.000       0.000       0.000\n",
      "(LabelSet={7})                191           0         191           0       0.000       0.000       0.000\n",
      "(LabelSet={8})                 80           0          80           0       0.000       0.000       0.000\n",
      "(LabelSet={9})                 92           0          92           0       0.000       0.000       0.000\n",
      "(LabelSet={10})                91           0          91           0       0.000       0.000       0.000\n",
      "(LabelSet={11})               688         684           4         226       0.994       0.752       0.856\n",
      "Total                       3,899       2,186       1,713         854\n",
      "Accuracy                                                                    0.561\n",
      "Micro Average                                                               0.561       0.719       0.630\n",
      "Macro Average                                                               0.319       0.399       0.335\n",
      "Balanced Error Rate                                                         0.681\n",
      "Jaccard Score                                                               0.497\n"
     ]
    }
   ],
   "source": [
    "var linTStartTime = System.currentTimeMillis();\n",
    "var linEval = eval.evaluate(linModel,test);\n",
    "var linTEndTime = System.currentTimeMillis();\n",
    "System.out.println();\n",
    "System.out.println(\"Linear model evaluation took \" + Util.formatDuration(linTStartTime,linTEndTime));\n",
    "System.out.println(linEval);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, the decision tree:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Tree model evaluation took (00:00:00:094)\n",
      "Class                           n          tp          fn          fp      recall        prec          f1\n",
      "(LabelSet={12})               683         607          76         201       0.889       0.751       0.814\n",
      "(LabelSet={13})                13           0          13           2       0.000       0.000       0.000\n",
      "(LabelSet={0})                286         111         175          98       0.388       0.531       0.448\n",
      "(LabelSet={1})                393         187         206         181       0.476       0.508       0.491\n",
      "(LabelSet={2})                385         251         134         193       0.652       0.565       0.606\n",
      "(LabelSet={3})                330         131         199          67       0.397       0.662       0.496\n",
      "(LabelSet={4})                281          92         189          41       0.327       0.692       0.444\n",
      "(LabelSet={5})                219          88         131         189       0.402       0.318       0.355\n",
      "(LabelSet={6})                167          30         137          41       0.180       0.423       0.252\n",
      "(LabelSet={7})                191          29         162          46       0.152       0.387       0.218\n",
      "(LabelSet={8})                 80           1          79          12       0.013       0.077       0.022\n",
      "(LabelSet={9})                 92          14          78          35       0.152       0.286       0.199\n",
      "(LabelSet={10})                91          20          71          83       0.220       0.194       0.206\n",
      "(LabelSet={11})               688         633          55         202       0.920       0.758       0.831\n",
      "Total                       3,899       2,194       1,705       1,391\n",
      "Accuracy                                                                    0.563\n",
      "Micro Average                                                               0.563       0.612       0.586\n",
      "Macro Average                                                               0.369       0.439       0.384\n",
      "Balanced Error Rate                                                         0.631\n",
      "Jaccard Score                                                               0.439\n"
     ]
    }
   ],
   "source": [
    "var dtTStartTime = System.currentTimeMillis();\n",
    "var dtEval = eval.evaluate(dtModel,test);\n",
    "var dtTEndTime = System.currentTimeMillis();\n",
    "System.out.println();\n",
    "System.out.println(\"Tree model evaluation took \" + Util.formatDuration(dtTStartTime,dtTEndTime));\n",
    "System.out.println(dtEval);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see the native multi-label linear model outperformed the wrapped decision tree in terms of Jaccard Score, though the picture is more mixed in the other metrics, and the linear model is ignoring some of the labels.\n",
    "\n",
    "Unfortunately some of the metrics we might like to examine for regular multi-class classification aren't as easy to use in the multi-label case. For example, a multi-class confusion matrix has no direct analogue in the multi-label case, as there could be an arbitrary number of labels predicted for each output, meaning there is no notion of a label being mispredicted as another label. This means a multi-label confusion matrix is best presented as a series of binary confusion matrices, one per label. This tends to take up a lot of space, so we'll skip inspecting them in this tutorial, though they are accessible on the `MultiLabelEvaluation` object.\n",
    "\n",
    "Now let's look at a more complicated multi-label classification approach, using *Classifier Chains*.\n",
    "\n",
    "## Training Classifier Chains\n",
    "\n",
    "A [classifier chain](http://www.cs.waikato.ac.nz/~ml/publications/2009/chains.pdf) is similar to a binary relevance model, except there is a sequential order to the predictions (forming a chain), and each member of the chain receives extra features in the form of the predictions of earlier members of the chain. This means that if the chain is correctly ordered according to the causal structure of the labels (which is tricky to do) then it can start with the most independent label first, and then predict each label in sequence so it can use the earlier predictions to improve predictions for each subsequent label (e.g., we could predict if the example is an \"animal\" first, and then when we come to predict if it's a \"human\" we know that humans are animals making the prediction task easier).\n",
    "\n",
    "In practice we don't usually know the correct ordering of the labels as the causal structure is unknown, and if we supply the incorrect structure then we can reduce performance back to the level of the binary relevance models. Fortunately in Machine Learning we have a trick we can use when we need to deal with uncertain data, which is to randomize it many times, and take an average. So we could take many different classifier chains each with an random label order, and then each chain votes on the labels that should be predicted. This improves statistical performance over a single chain with a random order, and over a single chain with a poorly chosen order, though it's unlikely to beat a single classifier chain with the correct label ordering (if such an ordering exists). Unfortunately the classifier chain ensemble is more expensive computationally than the single chain, which is already relatively expensive compared to a single classifier like `LinearSGDModel`, but the chains can be straightforwardly parallelized (and we'll add support for this to a future version of Tribuo).\n",
    "\n",
    "We're going to use a single classifier chain with a random order, and then an ensemble of 20 classifier chains each using random orders to see how they perform."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "var ccTrainer = new ClassifierChainTrainer(treeTrainer,1L);\n",
    "var ccEnsembleTrainer = new CCEnsembleTrainer(treeTrainer,20,1L);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First we'll train and evaluate the single chain:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Classifier Chain model training took (00:00:02:893)\n",
      "Classifier Chain model evaluation took (00:00:00:153)\n",
      "Class                           n          tp          fn          fp      recall        prec          f1\n",
      "(LabelSet={12})               683         616          67         203       0.902       0.752       0.820\n",
      "(LabelSet={13})                13           0          13           2       0.000       0.000       0.000\n",
      "(LabelSet={0})                286         159         127         172       0.556       0.480       0.515\n",
      "(LabelSet={1})                393         215         178         213       0.547       0.502       0.524\n",
      "(LabelSet={2})                385         251         134         193       0.652       0.565       0.606\n",
      "(LabelSet={3})                330         199         131         151       0.603       0.569       0.585\n",
      "(LabelSet={4})                281         112         169         124       0.399       0.475       0.433\n",
      "(LabelSet={5})                219          74         145         116       0.338       0.389       0.362\n",
      "(LabelSet={6})                167          39         128          48       0.234       0.448       0.307\n",
      "(LabelSet={7})                191          40         151          51       0.209       0.440       0.284\n",
      "(LabelSet={8})                 80           0          80          14       0.000       0.000       0.000\n",
      "(LabelSet={9})                 92           7          85          30       0.076       0.189       0.109\n",
      "(LabelSet={10})                91           8          83          32       0.088       0.200       0.122\n",
      "(LabelSet={11})               688         615          73         192       0.894       0.762       0.823\n",
      "Total                       3,899       2,335       1,564       1,541\n",
      "Accuracy                                                                    0.599\n",
      "Micro Average                                                               0.599       0.602       0.601\n",
      "Macro Average                                                               0.393       0.412       0.392\n",
      "Balanced Error Rate                                                         0.607\n",
      "Jaccard Score                                                               0.473\n"
     ]
    }
   ],
   "source": [
    "// train the model\n",
    "var ccStartTime = System.currentTimeMillis();\n",
    "var ccModel = ccTrainer.train(train);\n",
    "var ccEndTime = System.currentTimeMillis();\n",
    "System.out.println();\n",
    "System.out.println(\"Classifier Chain model training took \" + Util.formatDuration(ccStartTime,ccEndTime));\n",
    "\n",
    "// evaluate the model\n",
    "var ccTStartTime = System.currentTimeMillis();\n",
    "var ccEval = eval.evaluate(ccModel,test);\n",
    "var ccTEndTime = System.currentTimeMillis();\n",
    "System.out.println(\"Classifier Chain model evaluation took \" + Util.formatDuration(ccTStartTime,ccTEndTime));\n",
    "System.out.println(ccEval);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see the classifier chain improved over the binary relevance model when using trees as the base learner, and took roughly the same amount of time to train and evaluate. It's still not quite up to the linear model, but let's try the chain ensemble and see how it does.\n",
    "\n",
    "Now we'll train and evaluate the ensemble:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Classifier Chain Ensemble model training took (00:00:54:230)\n",
      "Classifier Chain Ensemble model evaluation took (00:00:02:249)\n",
      "Class                           n          tp          fn          fp      recall        prec          f1\n",
      "(LabelSet={12})               683         629          54         216       0.921       0.744       0.823\n",
      "(LabelSet={13})                13           0          13           1       0.000       0.000       0.000\n",
      "(LabelSet={0})                286         112         174          64       0.392       0.636       0.485\n",
      "(LabelSet={1})                393         170         223         140       0.433       0.548       0.484\n",
      "(LabelSet={2})                385         254         131         146       0.660       0.635       0.647\n",
      "(LabelSet={3})                330         194         136         111       0.588       0.636       0.611\n",
      "(LabelSet={4})                281         112         169          45       0.399       0.713       0.511\n",
      "(LabelSet={5})                219          49         170          40       0.224       0.551       0.318\n",
      "(LabelSet={6})                167          14         153           6       0.084       0.700       0.150\n",
      "(LabelSet={7})                191          14         177          19       0.073       0.424       0.125\n",
      "(LabelSet={8})                 80           0          80           0       0.000       0.000       0.000\n",
      "(LabelSet={9})                 92           0          92           4       0.000       0.000       0.000\n",
      "(LabelSet={10})                91           2          89           2       0.022       0.500       0.042\n",
      "(LabelSet={11})               688         640          48         205       0.930       0.757       0.835\n",
      "Total                       3,899       2,190       1,709         999\n",
      "Accuracy                                                                    0.562\n",
      "Micro Average                                                               0.562       0.687       0.618\n",
      "Macro Average                                                               0.337       0.489       0.359\n",
      "Balanced Error Rate                                                         0.663\n",
      "Jaccard Score                                                               0.485\n"
     ]
    }
   ],
   "source": [
    "// train the model\n",
    "var ccEnsembleStartTime = System.currentTimeMillis();\n",
    "var ccEnsembleModel = ccEnsembleTrainer.train(train);\n",
    "var ccEnsembleEndTime = System.currentTimeMillis();\n",
    "System.out.println();\n",
    "System.out.println(\"Classifier Chain Ensemble model training took \" + Util.formatDuration(ccEnsembleStartTime,ccEnsembleEndTime));\n",
    "\n",
    "// evaluate the model\n",
    "var ccETStartTime = System.currentTimeMillis();\n",
    "var ccEnsembleEval = eval.evaluate(ccEnsembleModel,test);\n",
    "var ccETEndTime = System.currentTimeMillis();\n",
    "System.out.println(\"Classifier Chain Ensemble model evaluation took \" + Util.formatDuration(ccETStartTime,ccETEndTime));\n",
    "System.out.println(ccEnsembleEval);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As expected the classifier chain ensemble outperformed the binary relevance model and the single classifier chain when using trees as the base learner, at the cost of the greatest runtime. It did this by significantly decreasing the number of false positives, at the cost of a small increase in false negatives. We didn't quite beat the performance of the linear model in terms of Jaccard score, but in general classifier chains are a powerful multi-label approach, and we could always use the linear model as a base learner (and if you do, then you do improve the Jaccard score above 0.497). We leave the implementation of that as an exercise for the reader.\n",
    "\n",
    "## Conclusion\n",
    "\n",
    "We looked at Tribuo's multi-label classification package, trying out several different models using different approaches to the multi-label problem, namely binary relevance models and classifier chains. We're interested in expanding Tribuo's support for multi-label problems, so if there are algorithms or metrics Tribuo is missing head over to our [Github page](https://github.com/oracle/tribuo) and contributions are always welcome."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "12+33"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
